# -*- coding: utf-8 -*-
"""
Created on Mon Jul  5 09:55:08 2021

@author: os4875st

Selected functions in this module (the list is not exhaustive):
- rescale_from_uint16_to_uint8(img)
- show_img(I, title_text='', title='', cmap='viridis', sz=10, fsz=15, vmin=-1, vmax=-1, xlim=[-1,-1], ylim=[-1,-1], show_axis=False)
"""
import numpy as np
import matplotlib.pyplot as plt
import tifffile
import copy
import tiff_image_manipulation as ti
import inspect
import fun_figs as ff
import os
import cv2    
from skimage import exposure

def pi(pic):
    """
    Print a summary of an image array: its type, dtype, shape, size and
    intensity statistics (both plain and NaN-ignoring).
    """
    dims = pic.shape
    print(f'Type of the image: {type(pic)}, {pic.dtype}')
    print(f'Shape of the image: {dims}')
    if len(dims) > 1:
        print(f'Image Height: {dims[-2]}')
    print(f'Image Width: {dims[-1]}')
    print(f'Dimension of Image: {pic.ndim}')
    print(f'Image size: {pic.size}')
    print(f'Maximum RGB value in this image: {pic.max()}')
    print(f'Minimum RGB value in this image: {pic.min()}')

    # Statistics over all values (NaNs propagate).
    mean_v = np.round(np.mean(pic), 5)
    median_v = np.round(np.median(pic), 5)
    std_v = np.round(np.std(pic), 5)
    print(f'mean = {mean_v}, median = {median_v}, std = {std_v}\n')

    # Same statistics with NaNs skipped.
    nan_mean = np.round(np.nanmean(pic), 5)
    nan_median = np.round(np.nanmedian(pic), 5)
    nan_std = np.round(np.nanstd(pic), 5)
    print(f'(ignoring nan) mean = {nan_mean}, median = {nan_median}, std = {nan_std}, min = {np.nanmin(pic):.5f}, max = {np.nanmax(pic):.5f}\n')

def rescale_from_uint16_to_uint8(img):
    """Map a uint16 numpy array linearly onto the uint8 range.

    The full uint16 range [0, 2**16) is scaled onto [0, 256) and truncated.

    Raises:
        Exception: if `img` is not uint16.
    """
    print(f'\tRescaling image {img.shape} from uint16 to uint8')
    if img.dtype == 'uint16':
        scaled = img.astype(np.float32) * 255.999 / (2**16)
        return scaled.astype(np.uint8)
    raise Exception(f'Data type should be be uint16 and not {img.dtype}...')
    
def rescale_to_uint8(img, lims=None):
    """Rescale a numpy array (uint16 or float32) to uint8.

    Conversion rules, checked in this order:
        - uint8: returned unchanged (same object).
        - uint16 with `lims`: scaled so that lims[1] maps to 255. This is useful for
          dim images, to keep intensity resolution. Values above lims[1] saturate
          at 255 instead of wrapping around. lims[0] is not used.
        - uint16 without `lims`: the full uint16 range is mapped onto 0-255.
        - float32 whose max is > 2**16: the range [0, 2**32] is mapped onto 0-255
          (values outside that range saturate at 0 / 255).
        - any other array whose max is <= 255: cast to uint8 without scaling.
        - anything else: the image statistics are printed and an Exception is raised.

    Args:
        img (np.ndarray): array to be rescaled.
        lims (sequence, optional): [lower, upper] rescaling limits. Only the upper
            limit is used. None or an empty sequence means "no limits".

    Returns:
        np.ndarray: uint8 array.

    Raises:
        Exception: if the dtype / value-range combination is not supported.
    """
    if img.dtype == 'uint8':
        print('\tNo rescaling needed. Image was already uint8')
        return img

    if img.dtype == 'uint16':
        if lims is not None and len(lims) > 0:
            print(f'\tRescaling image {img.shape} from {img.dtype} to uint8 with limits: {lims}')
            scaled = img.astype(np.float32) * 255.999 / (lims[1])
            # Pixels brighter than the upper limit exceed 255; casting them to uint8
            # directly would wrap them around to dark values, so saturate instead.
            img = np.clip(scaled, 0, 255).astype(np.uint8)
        else:
            print(f'\tRescaling image {img.shape} from {img.dtype} to uint8')
            img = (img.astype(np.float32) * 255.999 / (2**16)).astype(np.uint8)
    elif img.dtype == 'float32' and np.max(img) > 2**16:
        print(f'\tRescaling image {img.shape} from {img.dtype} to uint8')
        scaled = img.astype(np.float64) * 255.999 / (2**32)
        # Guard against values outside [0, 2**32] wrapping on the uint8 cast.
        img = np.clip(scaled, 0, 255).astype(np.uint8)
    elif np.max(img) <= 255:
        print(f'\tNo rescaling performed as the max intensity is lower than 255')
        img = img.astype(np.uint8)
    else:
        pi(img)
        raise Exception(f'Incorrect data type {img.dtype}...')
    return img

def merge_RG_channels_from_16_bit_nparray(I_org):
    """
    Combine two 16-bit grayscale channels into one 8-bit RGB image.

    I_org holds the two channels along its first axis (I_org[0], I_org[1]).
    Both are scaled from the uint16 range to uint8. Channel 0 becomes red,
    channel 1 becomes green, and blue is left at zero. The result has
    shape (h, w, 3).
    """
    scaled = (255.999 * I_org / (2**16)).astype(np.uint8)
    red = scaled[0, :, :]
    green = scaled[1, :, :]
    blue = np.zeros_like(red)
    return np.dstack((red, green, blue))

def show_img(I, title_text = '', title='', cmap='viridis', sz=10, fsz = 15, vmin=-1, vmax=-1, xlim=[-1,-1], ylim=[-1,-1], show_axis=False):
    """
    Show a 2D array in a new sz x sz inch matplotlib figure.

    Args:
        I: array to display with imshow.
        title_text: fallback title, used only when `title` is empty.
        title: figure title (takes precedence over `title_text`).
        cmap: colormap name, e.g. "gray", "viridis" or "inferno"
            (see https://matplotlib.org/stable/tutorials/colors/colormaps.html).
        sz: figure width and height in inches.
        fsz: title font size.
        vmin, vmax: display range; applied only when vmin >= 0 and vmax > 0.
        xlim, ylim: axis limits; ignored while their last element is -1.
        show_axis: draw the axes when True (hidden by default).

    NOTE(review): plt.style.use('general') runs after the image is drawn, so it
    presumably mostly affects later figures; it also requires a 'general' style.
    """
    fig = plt.figure(figsize=(sz, sz))
    ax = fig.add_subplot(1, 1, 1)
    if not show_axis:
        ax.axis('off')

    imshow_kwargs = {'cmap': cmap}
    if vmin >= 0 and vmax > 0:
        imshow_kwargs['vmin'] = vmin
        imshow_kwargs['vmax'] = vmax
    ax.imshow(I, **imshow_kwargs)

    heading = title if len(title) > 0 else title_text
    if len(heading) > 0:
        plt.title(heading, fontsize=fsz)

    if ylim[-1] != -1:
        plt.ylim(ylim)
    if xlim[-1] != -1:
        plt.xlim(xlim)
    plt.style.use('general')
    plt.show()

def simple_threshold(img, thresh_val=2):
    """
    Return a copy of `img` in which every value less than or equal to
    `thresh_val` is set to 0. The input array is left untouched.
    """
    result = copy.deepcopy(img)
    result[result <= thresh_val] = 0
    return result

def plot_sub_plots(I_list, titles, n_cols=4, sz=[5,5], fsz = 30, cmap='gray', add_aniso_cbar=-1, add_aniso_hsv_cbar=-1, aniso_hsv_cbar_range=[-0.1, 0.1], suptit='', display_range=[-1], save_imgs=False, file_name='', adjust_hspace=0.3):
    """
    Plot a (1D) list of images on a grid of subplots, one title per image.

    Input arguments:
        I_list  -   List of images, [np.array]. Not modified by this function.
        titles  -   List of titles, [string], one per image. Not modified.
        n_cols  -   Number of columns, int
        sz      -   Figure dimensions, width and then height [float, float]
        fsz     -   Title font size, float
        cmap    -   Colormap passed to imshow
        add_aniso_cbar       - Index of the image drawn with the fixed display
                               range [-0.15, 0.15] (-1: none)
        add_aniso_hsv_cbar   - Index of the image that gets the text labels 'r'
                               and the two values of aniso_hsv_cbar_range (-1: none)
        aniso_hsv_cbar_range - Two values written as text on that image
        suptit  -   Figure super-title
        display_range - [vmin, vmax] for imshow; ignored when display_range[0] < 0
        save_imgs - If True, save the figure to `file_name` and (on Windows) open
                    the folder containing this module
        file_name - Output path used when saving
        adjust_hspace - Vertical spacing between subplot rows

    Subplots whose image is entirely zero are left blank (no image, no title).
    The last row is padded with such blank images to fill the grid.
    An empty `I_list` plots nothing.
    """
    n = len(I_list)
    if n == 0:
        return  # nothing to plot (plt.subplots would fail with 0 rows)

    # Pad local copies so that the caller's lists are not extended in place.
    I_list = list(I_list)
    titles = list(titles)
    n_rows = int(np.ceil(n / n_cols))
    n_to_add = n_cols * n_rows - n
    for _ in range(n_to_add):
        I_list.append(np.zeros_like(I_list[0]))
        titles.append('')

    # squeeze=False keeps `axs` a 2D array even for a 1x1 grid, so .ravel() always works.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(sz[0], sz[1]), squeeze=False)
    plt.subplots_adjust(hspace=adjust_hspace)
    plt.suptitle(suptit)
    for i, ax in enumerate(axs.ravel()):
        if not np.all(I_list[i] == 0):
            if i == add_aniso_cbar:
                ax.imshow(I_list[i], cmap=cmap, vmin=-0.15, vmax=0.15)
            elif display_range[0] >= 0:
                ax.imshow(I_list[i], cmap=cmap, vmin=display_range[0], vmax=display_range[1])
            else:
                ax.imshow(I_list[i], cmap=cmap)
            if i == add_aniso_hsv_cbar:
                ax.text(25, 30, 'r', color='k', fontsize=13)
                ax.text(65, 17, str(aniso_hsv_cbar_range[0]), color='r', fontsize=13)
                ax.text(2, 80, str(aniso_hsv_cbar_range[1]), color='lime', fontsize=13)
            ax.set_title(titles[i], fontsize=fsz)
        ax.axis('off')
        ax.set_anchor('N')  # align the subplot to the top (N = north)
    if save_imgs:
        plt.savefig(file_name, bbox_inches="tight")
        print(f'Saving Histogram plot at {file_name}')
        currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
        if hasattr(os, 'startfile'):  # os.startfile only exists on Windows
            os.startfile(currentdir)
    plt.show()

def mask_gray_img(image_gray, mask, color=[255, 0, 0], alpha=0.3):
    '''
    Overlay a coloured mask on a grayscale image.

    Args:
        image_gray: 2D grayscale image. Non-uint8 input is first converted with
            rescale_to_uint8; the intensities are then stretched with
            skimage.exposure.rescale_intensity.
        mask: boolean array or np.where index tuple selecting the pixels to colour.
        color: BGR triplet [B, G, R]. Default [255, 0, 0] is blue
            (e.g. [0, 255, 255] would be yellow).
        alpha: float in [0, 1], weight of the coloured layer.

    Returns:
        BGR uint8 image: cv2.addWeighted(img, 1, layer, alpha, 0), where `layer`
        is the image with the masked pixels set to `color`.

    NOTE(review): with weights (1, alpha), unmasked pixels are also brightened by
    a factor (1 + alpha). The referenced recipe uses (alpha, 1 - alpha); confirm
    that the brighter output is intended.

    Ref: http://www.pyimagesearch.com/2016/03/07/transparent-overlays-with-opencv/
         https://gist.github.com/Puriney/8f89b43d96ddcaf0f560150d2ff8297e
    '''
    if image_gray.dtype != 'uint8':
        image_gray = rescale_to_uint8(image_gray)
    image_gray = exposure.rescale_intensity(image_gray)
    img = cv2.cvtColor(image_gray, cv2.COLOR_GRAY2BGR)
    img_layer = img.copy()
    img_layer[mask] = color
    # addWeighted does not modify its inputs, so no defensive copy of `img` is needed.
    return cv2.addWeighted(img, 1, img_layer, alpha, 0)

def blend_images(img1, img2, alpha=0.5, beta=0.7, gamma=0):
    """
    Blend two images with OpenCV: ``img1 * alpha + img2 * beta + gamma``.

    Args:
        img1: first input image.
        img2: second input image, same size and number of channels as img1.
        alpha: weight of img1 (default 0.5).
        beta: weight of img2 (default 0.7).
        gamma: scalar added to each weighted sum (default 0).

    Returns:
        The blended image from cv2.addWeighted (same size and channels as the inputs).

    Note: the default weights sum to 1.2, so the result is brighter than a plain
    average. OpenCV saturates values that exceed the output type's range
    (e.g. 255 for uint8).
    """
    return cv2.addWeighted(img1, alpha, img2, beta, gamma)

def crop_border(img, b=5):
    """
    Remove a frame of width `b` pixels from every side of an image.

    The last two axes are treated as height and width. A 3D array is cropped
    frame-wise (first axis kept whole); any other array is sliced on its first
    two axes.
    """
    height, width = img.shape[-2], img.shape[-1]
    rows = slice(b, height - b)
    cols = slice(b, width - b)
    if len(img.shape) == 3:
        return img[:, rows, cols]
    return img[rows, cols]

def crop_image(img, crop_array_imageJ):
    """
    Crop an image or stack with an ImageJ-style rectangle.

    From ImageJ's documentation: x and y are the coordinates (in pixels) of the
    upper-left corner of the selection, and the origin (0, 0) is the upper-left
    corner of the image.

    Args:
        img (np.ndarray): 2D (h, w), 3D (n, h, w) or 4D (n, h, w, c) array.
        crop_array_imageJ (list or np.ndarray): [x, y, width, height].
            An all-zero rectangle disables cropping.

    Returns:
        np.ndarray: the cropped view of `img` (or `img` itself when not cropping).

    Raises:
        Exception: if the rectangle is taller or wider than the image.
    """
    roi = crop_array_imageJ
    if isinstance(roi, list):
        roi = np.array(roi)
    if np.all(roi == 0):
        return img  # all-zero rectangle: nothing to crop

    n_dims = len(img.shape)
    if n_dims == 2:
        h_img, w_img = img.shape[0], img.shape[1]
    elif n_dims > 2:
        h_img, w_img = img.shape[1], img.shape[2]

    x1, width = roi[0], roi[2]
    y1, height = roi[1], roi[3]
    x2, y2 = x1 + width, y1 + height
    print(f'\tCropping image according to: {roi} x=({x1},{x2}), y = ({y1},{y2})')
    if height > h_img:
        raise Exception(f'The cropping rectangle height ({height}) cannot be larger than the image height ({h_img})!')
    if width > w_img:
        raise Exception(f'The cropping rectangle width ({width}) cannot be wider than the image width ({w_img})!')

    rows, cols = slice(y1, y2), slice(x1, x2)
    if n_dims == 3:
        return img[:, rows, cols]
    if n_dims == 4:
        return img[:, rows, cols, :]
    return img[rows, cols]

def set_border_to_val(img, b=1, val=0):
    """
    Return a copy of `img` whose outer frame of width `b` is set to `val`
    (black by default). The first two axes are treated as rows and columns;
    the input is not modified.
    """
    out = copy.deepcopy(img)
    n_rows, n_cols = out.shape[0], out.shape[1]
    out[:b, :] = val            # first b rows
    out[n_rows - b:, :] = val   # last b rows
    out[:, :b] = val            # first b columns
    out[:, n_cols - b:] = val   # last b columns
    return out


def calc_percentiles(img, min=2, max=98):
    """
    Return the `min`-th and `max`-th percentiles of a numpy array.

    Parameters:
        img: numpy array
        min: lower percentile (0-100)
        max: upper percentile (0-100)

    Returns:
        tuple: (lower_value, upper_value)
    """
    # `min`/`max` shadow the builtins here; the names are kept for keyword callers.
    return tuple(np.percentile(img, q) for q in (min, max))

def read_tiff(file_path, frame_range=[0]):
    """Read a TIFF file with tifffile.

    Args:
        file_path (str): path of the TIFF file.
        frame_range (list, optional): frames/pages to read, passed to
            tifffile.imread as `key`. If every entry is 0 (the default [0]),
            the whole file is read.

    Returns:
        np.ndarray: the image data.
    """
    read_everything = not np.any(frame_range)
    if read_everything:
        return tifffile.imread(file_path)
    return tifffile.imread(file_path, key=frame_range)

def add_small_color_bar_in_upper_right_corner_img(fig, ticks=[-0.15, 0, 0.15]):
    """
    Add a small vertical colour bar in the upper-right part of `fig`.

    The bar is drawn in a new axes at figure coordinates
    [0.70, 0.65, 0.05, 0.2] (xmin, ymin, dx, dy). It is created with
    plt.colorbar, so it uses the current mappable. Tick labels use the
    '% 1.2f' format and font size 13.

    Args:
        fig: matplotlib figure that receives the colour-bar axes.
        ticks: tick positions on the colour bar. This argument used to be
            ignored in favour of a hard-coded list; the default is unchanged.

    Returns:
        The matplotlib Colorbar object.
    """
    position = fig.add_axes([0.70, 0.65, 0.05, 0.2])  # xmin, ymin, dx, dy
    colorbar_format = '% 1.2f'
    cbar = plt.colorbar(cax=position, ticks=ticks, orientation='vertical', format=colorbar_format)
    cbar.ax.tick_params(labelsize=13)
    return cbar

import time
import math
from PIL import Image
def _rotate_frame_bilinear(frame, angle, fillcolor):
    """Rotate one frame counter-clockwise with PIL (bilinear, canvas expanded) and return it as a numpy array."""
    pil_frame = Image.fromarray(frame)  # PIL conversion; expects 8-bit unsigned data
    pil_frame = pil_frame.rotate(angle, resample=Image.BILINEAR, expand=True, translate=None, fillcolor=fillcolor)
    return np.array(pil_frame)


def transform_img(img, angle=0, transform_mode='rot90', fillcolor=None):
    """
    Rotate an image, or each frame of a stack, counter-clockwise with bilinear
    interpolation.

    Args:
        img (np.array, uint8): 2D image (h, w), 3D stack (n, h, w) or 4D stack
            (n, h, w, c). Each frame goes through PIL.Image.fromarray, which
            expects 8-bit unsigned data.
        angle (float): rotation in degrees, counter-clockwise. NaN means
            "no rotation": the input is returned unchanged.
        transform_mode (str): 'rot90' and 'flipHor' have not been tested and
            raise ValueError (note that 'rot90' is the default). Any other value
            performs the rotation.
        fillcolor: colour for the area outside the rotated frame (passed to PIL).

    Returns:
        np.array (uint8): rotated image; the canvas is expanded so the whole
        rotated frame fits.

    Raises:
        ValueError: for the untested modes, or for arrays that are not 2D/3D/4D.

    Note: quite slow, because frames are rotated one by one with PIL. OpenCV
    might be faster, but no way to do bilinear or bicubic interpolation with it
    was found.
    """
    if transform_mode == 'rot90':
        raise ValueError('This mode (rot90) has not been tested')
    elif transform_mode == 'flipHor':
        raise ValueError('This mode (flipHor) has not been tested')
    if math.isnan(angle):
        # Nothing to do. Previously this path ended in an UnboundLocalError
        # because `img_rot` was never assigned.
        return img

    print(f'\tRotating the image {angle} degrees counter-clockwise')
    tic = time.perf_counter()
    n_dims = len(img.shape)
    if n_dims == 2:
        n_frames = 1
        rotated = _rotate_frame_bilinear(img, angle, fillcolor)
        img_rot = np.zeros([rotated.shape[0], rotated.shape[1]], dtype=np.uint8)
        img_rot[:, :] = rotated
    elif n_dims in (3, 4):
        n_frames = img.shape[0]
        # Rotate the first frame once to learn the expanded output size, then reuse it.
        first = _rotate_frame_bilinear(img[0], angle, fillcolor)
        img_rot = np.zeros((n_frames,) + first.shape, dtype=np.uint8)
        img_rot[0] = first
        for i in range(1, n_frames):
            img_rot[i] = _rotate_frame_bilinear(img[i], angle, fillcolor)
    else:
        raise ValueError(f'Wrong image shape: {img.shape}, len = {len(img.shape)}')
    toc = time.perf_counter()
    print(f'\t\tRotated {n_frames} frame(s) in {toc - tic:0.2f} seconds')
    return img_rot

def _plot_pixel_value_hist(values, title, color, value_range, show_uint16_limit=True):
    """Plot a 256-bin histogram of the non-zero entries of `values` in a new 15x4 figure ('ppt' style)."""
    fig, ax = plt.subplots(figsize=(15,4))
    plt.style.use('ppt')
    data = values.ravel()
    data = data[data != 0]
    plt.hist(data, bins=256, color=color, range=value_range)
    plt.title(title)
    plt.xlabel('pixel value')
    plt.ylabel('count')
    if show_uint16_limit:
        # Dashed reference line at 2**16, the uint16 range limit.
        plt.axvline(2**16, linestyle = '--', label='uint16 pixel value limit', color='k')
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.show()


def subtract_uneven_bg_from_illumination(img, file_path_bg, ignore_black_pixels=False, rescaling=True):
    """Correct uneven illumination using a recording of the background without the sample.

    The background image is loaded from `file_path_bg`. If needed it is
    centre-cropped to the frame size of `img`, then turned into the correction
    mask ``max(bg) / bg``, which is multiplied into `img`. Histograms of the
    intermediate pixel-value distributions are shown along the way.

    Args:
        img (np.ndarray): image stack; img.shape[1] and img.shape[2] are the frame
            height and width. Assumed to be (frames, h, w) — TODO confirm with callers.
        file_path_bg (str): path of the background recording (read with ti.read_tif_file).
        ignore_black_pixels (bool): if True, mask values above min(1.7, 97th percentile)
            are set to NaN. These presumably are dark regions outside the illuminated area.
        rescaling (bool): for uint16 input only. Rescale so the maximum becomes
            2**16 - 1, then cast back to uint16.

    Returns:
        np.ndarray: corrected image. It is uint16 when the input was uint16 and
        `rescaling` is True. Otherwise it is the floating-point product of the
        image and the float32 mask; for uint16 input, NaNs are first set to 0.

    Raises:
        ValueError: if the background image cannot be cropped to the frame size.
    """
    print('\tSubtracting uneven illumination background')
    dtype_initial = img.dtype
    img_bg = ti.read_tif_file(file_path_bg).astype(np.float32)

    # If the field of view is smaller than full size, crop the background image accordingly
    # (the FoV of the image is assumed to be centred on the camera sensor).
    if img_bg.shape != img.shape[1:]:
        print(f'\t\tReshaping the background image {img_bg.shape} so that it fits with the image {img.shape}')
        h_img_bg, w_img_bg = img_bg.shape[0], img_bg.shape[1]
        h, w = img.shape[1], img.shape[2]
        # Round only the start of the window and derive the end from the size. This
        # always yields exactly h x w; rounding both ends independently could be off by one.
        row1 = round(h_img_bg/2 - h/2)
        row2 = row1 + h
        col1 = round(w_img_bg/2 - w/2)
        col2 = col1 + w
        if row1 < 0 or col1 < 0 or row2 > h_img_bg or col2 > w_img_bg:
            raise ValueError(f'The background image {img_bg.shape} is too small to be cropped to the image frame size ({h}, {w})')
        img_bg = img_bg[row1:row2, col1:col2]
        if img_bg.shape != img.shape[1:]:
            raise ValueError('The dimensions of the bg image is not equal to that of the image after cropping')
        print(f'\t\tNew background image dimensions: {img_bg.shape}')

    _plot_pixel_value_hist(
        img_bg,
        f'Background Image Pixel Value Distribution\n(min: {np.nanmin(img_bg):.0f}, max: {np.nanmax(img_bg):.0f})',
        'r', (0, 1.5*2**16))

    # Normalize and invert the background image to create a mask:
    # the brightest background pixels become 1, dimmer ones get values above 1.
    # NOTE(review): background pixels equal to 0 give inf here; confirm the recordings contain none.
    max_val = np.max(img_bg)
    mask_bg = max_val/img_bg

    if ignore_black_pixels:
        # Mask out very dark background regions (large mask values) by setting them to NaN.
        lim_mask_bg = 1.7
        perc_97 = np.percentile(mask_bg, 97)
        if perc_97 < lim_mask_bg:
            lim_mask_bg = perc_97
        print(f'All normalized background values above {lim_mask_bg} are set to NaN (97th percentile: {perc_97})')
        mask_black = mask_bg > lim_mask_bg
        mask_bg[mask_black] = np.nan

        ff.show_img(mask_black, title_text = 'Black Mask (set to NaN)', cmap='gray', sz=5, fsz = 15)

        _plot_pixel_value_hist(
            mask_bg,
            f'Normalized Background Image Pixel Value Distribution\n(min: {np.nanmin(mask_bg):.0f}, max: {np.nanmax(mask_bg):.0f} [97th percentile])',
            'r', (0, lim_mask_bg), show_uint16_limit=False)

    _plot_pixel_value_hist(
        img,
        f'Image Pixel Value Distribution\nbefore background subtraction\n(min: {np.nanmin(img):.0f}, max: {np.nanmax(img):.0f})',
        'b', (0, 1.5*2**16))

    print(f'\tSubtracting the background: (min: {np.nanmin(mask_bg):.0f}, max: {np.nanmax(mask_bg):.0f})')

    # Multiply the image with the inverse background illumination.
    img = img * mask_bg

    _plot_pixel_value_hist(
        img,
        f'Image Pixel Value Distribution\nafter background subtraction\n(min: {np.nanmin(img):.0f}, max: {np.nanmax(img):.0f})',
        'b', (0, 1.5*2**16))

    if dtype_initial == 'uint16':
        # Set pixels with NaN (e.g. masked-out black pixels) to 0.
        img[np.isnan(img)] = 0

        if rescaling:
            # Rescale so the maximum becomes the largest uint16 value (2**16 - 1).
            # Scaling to 2**16 would wrap the brightest pixel around to 0 on the cast.
            uint16_max = 2**16 - 1
            I_max = np.max(img)
            if I_max > 0:
                factor = uint16_max / I_max
                print(f'Maximum pixel value: {I_max:.0f}\nMultiplying image with factor {factor:.3f}')
                img = img * factor
            img = np.clip(img, 0, uint16_max).astype(dtype_initial)

        _plot_pixel_value_hist(
            img,
            f'Image Pixel Value Distribution\nafter final rescaling\n(min: {np.nanmin(img):.0f}, max: {np.nanmax(img):.0f})',
            'b', (0, 1.5*2**16))

    return img

def save_z_stack(img, mode='med', file_name0 = '',  txt_extra = '', folder=''):
    """
    Project a stack along its first axis and write the result as a TIFF file.

    Args:
        img (np.ndarray): stack projected along axis 0 (presumably the z/frame
            axis — confirm against callers).
        mode (str): 'med' (median), 'mean' or 'max' projection.
        file_name0 (str): last part of the output file name.
        txt_extra (str): text inserted between `mode` and `file_name0` in the file name.
        folder (str): output directory.

    The result is written with ti.write_tif_file to
    os.path.join(folder, mode + txt_extra + file_name0 + '.tiff').

    Raises:
        ValueError: if `mode` is not 'med', 'mean' or 'max'. Previously this
            surfaced as an UnboundLocalError.
    """
    projections = {'med': np.median, 'mean': np.mean, 'max': np.max}
    try:
        project = projections[mode]
    except KeyError:
        raise ValueError(f"Unknown projection mode {mode!r}; expected one of {sorted(projections)}") from None
    img_z_stack = project(img, axis=0)
    file_path = os.path.join(folder, mode+txt_extra+file_name0+'.tiff')
    ti.write_tif_file(file_path, img_z_stack)

def gen_p_string(list_v):
    """
    Concatenate the `p` attribute of each item, with '.' replaced by ','.

    Args:
        list_v: iterable (list or generator) of objects with a string attribute `p`.

    Returns:
        str: e.g. items with p='1.5' and p='2' give '1,52'; an empty input gives ''.
    """
    # A single pass keeps generator input working. The old unused np.array
    # pre-pass consumed it and returned ''.
    return ''.join(v.p.replace('.', ',') for v in list_v)

def save_plot(list_v, v, str='', dir_save=''):
    """
    Save the current matplotlib figure as '<str>_<p-string>.png' and open its folder.

    Args:
        list_v: objects whose `p` attributes form the file-name suffix (see gen_p_string).
        v: object with a `dir_exp` attribute; used only when `dir_save` is empty.
            The plot then goes to os.path.join(v.dir_exp, 'plots').
        str: file-name prefix. The parameter name shadows the builtin and is kept
            for caller compatibility.
        dir_save: output directory; created (including parents) if missing.

    The folder is opened with os.startfile where available (Windows only).
    """
    p_str = gen_p_string(list_v)
    file_name = f'{str}_{p_str}.png'
    dir_plots = dir_save if len(dir_save) > 0 else os.path.join(v.dir_exp, 'plots')
    # makedirs also creates missing parents and tolerates an existing folder.
    os.makedirs(dir_plots, exist_ok=True)
    path_save = os.path.join(dir_plots, file_name)
    plt.savefig(path_save, bbox_inches="tight")
    print(f'Saving plot at {path_save}')
    if hasattr(os, 'startfile'):  # os.startfile only exists on Windows
        os.startfile(dir_plots)
